package com.googlecode.mapperdao.jdbc

import org.springframework.jdbc.core.JdbcOperations
import org.springframework.jdbc.core.BatchPreparedStatementSetter
import org.springframework.jdbc.core.PreparedStatementCallback
import java.sql.PreparedStatement
import org.springframework.jdbc.core.ColumnMapRowMapper
import org.springframework.jdbc.core.PreparedStatementCreator
import java.sql.Connection
import org.springframework.jdbc.core.SqlParameterValue
import org.springframework.jdbc.core.StatementCreatorUtils
import org.springframework.jdbc.core.SqlTypeValue

/**
 * Executes batched inserts/updates through Spring's [[JdbcOperations]].
 *
 * @author kostantinos.kougios
 *
 *         Nov 14, 2012
 */
protected class Batch(jdbc: JdbcOperations)
{
	/**
	 * Runs `sql` once for every row of `args`, delegating the actual execution to `options.strategy`.
	 *
	 * @param sql     the insert/update statement, with positional parameters bound by index
	 * @param args    one array of parameter values per row; element `j` of a row is bound to JDBC parameter `j + 1`
	 * @param options the execution strategy and the names of the autogenerated columns to read back (empty for none)
	 * @return the per-row update counts and the generated keys, as produced by `options.strategy`
	 */
	def batchUpdate(sql: String, args: Array[Array[SqlParameterValue]], options: BatchOptions): BatchResult = {

		val pss = new BatchPreparedStatementSetter
		{
			override def setValues(ps: PreparedStatement, i: Int): Unit = {
				val row = args(i)
				// TYPE_UNKNOWN is only a fallback: per StatementCreatorUtils, a SqlParameterValue
				// argument supplies its own SQL type
				for (j <- row.indices)
					StatementCreatorUtils.setParameterValue(ps, j + 1, SqlTypeValue.TYPE_UNKNOWN, row(j))
			}

			override def getBatchSize: Int = args.length
		}

		jdbc.execute(new PreparedStatementCreator
		{
			// only ask the driver for generated keys when the caller requested some columns
			override def createPreparedStatement(con: Connection): PreparedStatement =
				if (options.autogenerated.length > 0)
					con.prepareStatement(sql, options.autogenerated)
				else
					con.prepareStatement(sql)
		}, new PreparedStatementCallback[BatchResult]
		{
			override def doInPreparedStatement(ps: PreparedStatement): BatchResult = {
				val (rowsAffected, keys) = options.strategy.process(ps, pss, options.autogenerated)
				BatchResult(rowsAffected, keys)
			}
		})
	}
}

/**
 * Settings for a single batch execution.
 *
 * @param strategy      the [[Batch.Strategy]] that executes the rows (e.g. `Batch.WithBatch` or `Batch.NoBatch`)
 * @param autogenerated names of the generated-key columns to read back; empty when no keys are needed
 */
case class BatchOptions(
	strategy: Batch.Strategy,
	autogenerated: Array[String]
	)

/**
 * Strategies for executing a batch of rows against one prepared statement,
 * plus the generated-key helper they share.
 */
object Batch
{

	private[mapperdao] trait Strategy
	{
		/**
		 * Executes every row supplied by `pss` against `ps`.
		 *
		 * @param ps            the prepared statement to execute
		 * @param pss           binds the parameters of row `i`; `getBatchSize` is the number of rows
		 * @param autogenerated names of the generated-key columns to read back; empty for none
		 * @return the update count of each row, and one generated-key map per row when `autogenerated`
		 *         is non-empty (the strategies in this object return `null` keys otherwise)
		 */
		def process(
			ps: PreparedStatement,
			pss: BatchPreparedStatementSetter,
			autogenerated: Array[String]
			): (Array[Int], Array[java.util.Map[String, Object]])
	}

	/**
	 * for drivers that don't support batch updates (or returning generated keys after batch updates):
	 * executes each row on its own and reads that row's single generated key right after it.
	 */
	object NoBatch extends Strategy
	{
		def process(ps: PreparedStatement, pss: BatchPreparedStatementSetter, autogenerated: Array[String]): (Array[Int], Array[java.util.Map[String, Object]]) = {
			val sz = pss.getBatchSize
			val wantKeys = autogenerated.length > 0
			val updates = new Array[Int](sz)
			val ka = if (wantKeys) new Array[java.util.Map[String, Object]](sz) else null
			for (i <- 0 until sz) {
				pss.setValues(ps, i)
				updates(i) = ps.executeUpdate
				// a single-row execution must yield exactly one generated-key row
				if (wantKeys) ka(i) = keys(ps, 1)(0)
			}
			(updates, ka)
		}
	}

	/**
	 * for drivers that support batch updates and returning generated keys after batch updates:
	 * adds every row to one JDBC batch, executes it, then reads one generated-key row per batched row.
	 */
	object WithBatch extends Strategy
	{
		def process(ps: PreparedStatement, pss: BatchPreparedStatementSetter, autogenerated: Array[String]): (Array[Int], Array[java.util.Map[String, Object]]) = {
			val sz = pss.getBatchSize
			for (i <- 0 until sz) {
				pss.setValues(ps, i)
				ps.addBatch()
			}
			val updates = ps.executeBatch
			// generated keys only exist once the batch has been executed
			val ka = if (autogenerated.length > 0) keys(ps, sz) else null
			(updates, ka)
		}
	}

	// utility methods

	// maps one generated-key row to a column-name -> value map
	private val rm = new ColumnMapRowMapper

	/**
	 * Reads the generated keys produced by the last execution of `ps` and closes the key result set.
	 *
	 * @param numOfKeys the number of generated-key rows the caller expects
	 * @return one column-name -> value map per generated-key row, in the order the driver returned them
	 * @throws IllegalStateException if the driver returned a different number of generated-key rows
	 */
	private def keys(ps: PreparedStatement, numOfKeys: Int): Array[java.util.Map[String, Object]] = {
		val result = new Array[java.util.Map[String, Object]](numOfKeys)
		val generatedKeys = ps.getGeneratedKeys
		try {
			var count = 0
			while (generatedKeys.next) {
				// rows beyond numOfKeys are only counted, so the mismatch is reported below with the
				// real count instead of failing with an ArrayIndexOutOfBoundsException
				if (count < numOfKeys) result(count) = rm.mapRow(generatedKeys, count)
				count += 1
			}
			if (count != numOfKeys)
				throw new IllegalStateException("expected " + numOfKeys + " autogenerated keys but was " + count)
			result
		} finally {
			generatedKeys.close()
		}
	}
}

/**
 * Outcome of a batch execution.
 *
 * Note: being a case class over arrays, `equals`/`hashCode` compare the array references, not their contents.
 *
 * @param rowsAffected the update count of each executed row
 * @param keys         the generated-key map of each row; `Batch.NoBatch` and `Batch.WithBatch` leave it `null`
 *                     when no autogenerated columns were requested
 */
case class BatchResult(
	rowsAffected: Array[Int],
	keys: Array[java.util.Map[String, Object]]
	)
